{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "iNeRF-needs-integration-with-public-jaxnerf",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cveXXkCALyDn"
      },
      "source": [
        "Copyright 2021 Google LLC.\n",
        "SPDX-License-Identifier: Apache-2.0\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xy7R3OHSK6cL"
      },
      "source": [
        "# iNeRF implementation.\n",
        "\n",
        "Implementation of \"iNeRF: Inverting Neural Radiance Fields for Pose Estimation\"\n",
        "Website: https://yenchenlin.me/inerf/\n",
        "\n",
        "Note: this implementation needs to be integrated with the public version of jaxnerf: https://github.com/google-research/google-research/tree/master/jaxnerf"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dyBQcKfwLqoP"
      },
      "source": [
        "# TODO(yenchenl): add pip installs."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FYF1FzIxI56D"
      },
      "source": [
        "import colabtools\n",
        "import functools\n",
        "import gc\n",
        "import time\n",
        "from absl import app\n",
        "from absl import flags\n",
        "from flax import jax_utils\n",
        "from flax import nn\n",
        "from flax import optim\n",
        "from flax.metrics import tensorboard\n",
        "from flax.training import checkpoints\n",
        "import getpass\n",
        "import jax\n",
        "from jax import config\n",
        "from jax import random\n",
        "from jax import numpy as jnp\n",
        "from jax import grad, jit, vmap, value_and_grad\n",
        "from jax.example_libraries import optimizers\n",
        "import matplotlib.pyplot as plt\n",
        "import PIL\n",
        "from PIL import Image as PilImage\n",
        "import numpy as np\n",
        "from scipy.spatial.transform import Rotation as R\n",
        "from six.moves import reload_module\n",
        "import yaml"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HuxtyC3qJC0G"
      },
      "source": [
        "from jax.lib import xla_bridge\n",
        "print(xla_bridge.get_backend().platform, jax.device_count())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WFXqSepmJPHm"
      },
      "source": [
        "# TODO(yenchenl): jaxnerf imports may need attention\n",
        "from jaxnerf.nerf import datasets\n",
        "from jaxnerf.nerf import model_utils\n",
        "from jaxnerf.nerf import models\n",
        "from jaxnerf.nerf import utils"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7tWHeShAJWmI"
      },
      "source": [
        "# TODO(yenchenl): update paths\n",
        "flags.DEFINE_string(\n",
        "    \"train_dir\",\n",
        "    \"/path_to/jaxnerf_models/blender/lego/\",\n",
        "    \"Experiment path.\")\n",
        "flags.mark_flag_as_required(\"train_dir\")\n",
        "flags.DEFINE_string(\n",
        "    \"data_dir\",\n",
        "    \"/path_to/datasets/nerf/nerf_synthetic/lego\",\n",
        "    \"Data path.\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QgKdyjpzJoCu"
      },
      "source": [
        "FLAGS = flags.FLAGS\n",
        "flags.DEFINE_integer(\"n_gpus\", 1, \"Number of gpus per train worker.\")\n",
        "flags.DEFINE_integer(\"n_gpus_eval\", 1, \"Number of gpus per eval worker.\")\n",
        "flags.mark_flag_as_required(\"data_dir\")\n",
        "flags.DEFINE_enum(\"config\", \"blender\", [\"blender\",\"llff\",],\n",
        "                  \"Choice of the reuse-able full configuration.\")\n",
        "flags.DEFINE_bool(\"is_train\", True, \"The job is in the training mode.\")\n",
        "flags.DEFINE_bool(\"use_tpu\", False, \"Whether to use tpu for training.\")\n",
        "flags.DEFINE_bool(\"use_tpu_eval\", False, \"Whether to use tpu for evaluation.\")\n",
        "flags.DEFINE_integer(\"render_every\", 0,\n",
        "                     \"the interval in optimization steps between rendering\"\n",
        "                     \"a validation example. 0 is recommended if using\"\n",
        "                     \"parallel train and eval jobs.\")\n",
        "flags.DEFINE_integer(\n",
        "    \"chunk\", None, \"the size of chunks for evaluation inferences, set to\"\n",
        "    \"the value that fits your GPU/TPU memory.\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NAbxctpaJqvQ"
      },
      "source": [
        "flags.DEFINE_enum(\"dataset\", \"blender\",\n",
        "                  list(k for k in datasets.dataset_dict.keys()),\n",
        "                  \"The type of dataset feed to nerf.\")\n",
        "flags.DEFINE_bool(\"image_batching\", False,\n",
        "                  \"sample rays in a batch from different images.\")\n",
        "flags.DEFINE_bool(\n",
        "    \"white_bkgd\", True, \"using white color as default background.\"\n",
        "    \"(used in the blender dataset only)\")\n",
        "flags.DEFINE_integer(\"batch_size\", 1024,\n",
        "                      \"the number of rays in a mini-batch (for training).\")\n",
        "flags.DEFINE_integer(\n",
        "    \"factor\", 4, \"the downsample factor of images, 0 for no downsample.\")\n",
        "flags.DEFINE_bool(\"spherify\", False, \"set for spherical 360 scenes.\")\n",
        "flags.DEFINE_bool(\n",
        "    \"render_path\", False, \"render generated path if set true.\"\n",
        "    \"(used in the llff dataset only)\")\n",
        "flags.DEFINE_integer(\n",
        "    \"llffhold\", 8, \"will take every 1/N images as LLFF test set.\"\n",
        "    \"(used in the llff dataset only)\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dX3Koco3JuAu"
      },
      "source": [
        "# Model Flags\n",
        "flags.DEFINE_enum(\"model\", \"nerf\", list(k for k in models.model_dict.keys()),\n",
        "                  \"name of model to use.\")\n",
        "flags.DEFINE_float(\"near\", 2., \"near clip of volumetric rendering.\")\n",
        "flags.DEFINE_float(\"far\", 6., \"far clip of volumentric rendering.\")\n",
        "flags.DEFINE_integer(\"net_depth\", 8, \"depth of the first part of MLP.\")\n",
        "flags.DEFINE_integer(\"net_width\", 256, \"width of the first part of MLP.\")\n",
        "flags.DEFINE_integer(\"net_depth_condition\", 1,\n",
        "                      \"depth of the second part of MLP.\")\n",
        "flags.DEFINE_integer(\"net_width_condition\", 128,\n",
        "                      \"width of the second part of MLP.\")\n",
        "flags.DEFINE_enum(\"activation\", \"relu\", [\"relu\",],\n",
        "                  \"activation function used in MLP.\")\n",
        "flags.DEFINE_integer(\n",
        "    \"skip_layer\", 4, \"add a skip connection to the output vector of every\"\n",
        "    \"skip_layer layers.\")\n",
        "flags.DEFINE_integer(\"alpha_channel\", 1, \"the number of alpha channels.\")\n",
        "flags.DEFINE_integer(\"rgb_channel\", 3, \"the number of rgb channels.\")\n",
        "flags.DEFINE_bool(\"randomized\", True, \"use randomized stratified sampling.\")\n",
        "flags.DEFINE_integer(\"deg_point\", 10,\n",
        "                      \"Degree of positional encoding for points.\")\n",
        "flags.DEFINE_integer(\"deg_view\", 4,\n",
        "                      \"degree of positional encoding for viewdirs.\")\n",
        "flags.DEFINE_integer(\"n_samples\", 64, \"the number of samples on each ray.\")\n",
        "flags.DEFINE_integer(\"n_fine_samples\", 128,\n",
        "                      \"the number of samples on each ray for the fine model.\")\n",
        "flags.DEFINE_bool(\"use_viewdirs\", True, \"use view directions as a condition.\")\n",
        "flags.DEFINE_float(\n",
        "    \"noise_std\", None, \"std dev of noise added to regularize sigma output.\"\n",
        "    \"(used in the llff dataset only)\")\n",
        "flags.DEFINE_bool(\"lindisp\", False,\n",
        "                  \"sampling linearly in disparity rather than depth.\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bdavhWFmJwFF"
      },
      "source": [
        "# Train Flags\n",
        "flags.DEFINE_float(\"lr\", 5e-4, \"Learning rate for training.\")\n",
        "flags.DEFINE_integer(\"lr_decay\", 500,\n",
        "                      \"the number of steps (in 1000s) for exponential\"\n",
        "                      \"learning rate decay.\")\n",
        "flags.DEFINE_integer(\"max_steps\", 1000000,\n",
        "                      \"the number of optimization steps.\")\n",
        "flags.DEFINE_integer(\"save_every\", 10000,\n",
        "                      \"the number of steps to save a checkpoint.\")\n",
        "flags.DEFINE_integer(\"gc_every\", 10000,\n",
        "                      \"the number of steps to run python garbage collection.\")\n",
        "\n",
        "# No randomization in eval!\n",
        "flags.DEFINE_bool(\"randomized\", False, \"Whether stochastic or not.\")\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nEn48jKfJyS4"
      },
      "source": [
        "def compute_pose_error(T_estobject_cam, T_gtobject_cam):\n",
        "  \"\"\"Compute scalars for rotation and translation error between two poses.\"\"\"\n",
        "  T_estobject_gtobject = T_estobject_cam @ np.linalg.inv(T_gtobject_cam)\n",
        "  rotation_error = np.arccos((np.trace(T_estobject_gtobject[:3, :3]) - 1) / 2)\n",
        "  translation_error = np.linalg.norm(T_estobject_gtobject[:3, -1])\n",
        "  return rotation_error * 180 / np.pi, translation_error"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w1NYk_QWJ0NY"
      },
      "source": [
        "## Load dataset."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MxwGwJdXJzaV"
      },
      "source": [
        "# TODO(yenchenl): blender config\n",
        "blender_cfg = yaml.load(Open('/path_to/nerf/blender.yaml'))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XI8m31-SJ7PH"
      },
      "source": [
        "rng = random.PRNGKey(20200823)\n",
        "# Shift the numpy random seed by host_id() to shuffle data loaded by different\n",
        "# hosts.\n",
        "np.random.seed(20201473 + jax.host_id())\n",
        "\n",
        "if FLAGS.config is not None:\n",
        "  FLAGS.__dict__.update(blender_cfg)\n",
        "if FLAGS.batch_size % jax.device_count() != 0:\n",
        "  raise ValueError(\"Batch size must be divisible by the number of devices.\")\n",
        "dataset = datasets.get_dataset(\"test\", FLAGS)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "16OiALmJKAJn"
      },
      "source": [
        "# Load pre-trained model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A4T-R7khKBri"
      },
      "source": [
        "rng, key = random.split(rng)\n",
        "init_model, init_state = models.get_model(key, FLAGS)\n",
        "dummy_optimizer_def = optim.Adam(FLAGS.lr)\n",
        "dummy_optimizer = dummy_optimizer_def.create(init_model)\n",
        "state = model_utils.TrainState(step=0, optimizer=dummy_optimizer,\n",
        "                               model_state=init_state)\n",
        "del init_model, init_state"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NgZ8GtmqKEgv"
      },
      "source": [
        "state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)\n",
        "nerf_model = state.optimizer.target"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sMGoXxlPKG8G"
      },
      "source": [
        "idx = 0\n",
        "test_image = dataset.images[idx]\n",
        "test_pixels = test_image.reshape([dataset.resolution, 3])\n",
        "test_pose = dataset.camtoworlds[idx]\n",
        "print(f\"Pixels/pose shapes: {test_pixels.shape}, {test_pose.shape}\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mj_CQswdKHgP"
      },
      "source": [
        "# Set the perturbation."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fvRxkrzPKIpu"
      },
      "source": [
        "perturbation = jax.numpy.zeros((4, 4))\n",
        "pred_pose = np.array(test_pose)\n",
        "\n",
        "USE_ROTATION_PERTURBATION = True\n",
        "USE_TRANSLATION_PERTURBATION = True\n",
        "\n",
        "if USE_ROTATION_PERTURBATION:\n",
        "  magnitude = 30.0  # Set the magnitue of rotation perturbation (degree).\n",
        "  magnitude_rad = magnitude / 180.0 * np.pi\n",
        "  direction = np.random.randn(3)\n",
        "  translation_magnitude = np.linalg.norm(direction)\n",
        "  eps = 1e-6\n",
        "  if translation_magnitude < eps:  # Prevents divide-by-0.\n",
        "    translation_magnitude = eps\n",
        "  direction = direction / translation_magnitude\n",
        "  perturbed_rotvec = direction * magnitude_rad\n",
        "  pred_rot_mat = R.from_rotvec(perturbed_rotvec).as_matrix()\n",
        "  delta = np.eye(4)\n",
        "  delta[:3, :3] = pred_rot_mat\n",
        "  pred_pose = delta @ pred_pose\n",
        "if USE_TRANSLATION_PERTURBATION:\n",
        "  magnitude = 0.05  # Set the magnitue of translation perturbation along xyz.\n",
        "  perturbation = jax.ops.index_add(\n",
        "      perturbation, jax.ops.index[:3, -1], magnitude)\n",
        "  pred_pose = pred_pose + perturbation\n",
        "\n",
        "\n",
        "print(f\"Initial pose error: {compute_pose_error(pred_pose, test_pose)}\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZYtFziU-KLUX"
      },
      "source": [
        "pred_pose_init = pred_pose * 1.0"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-oiv3kl6KOfH"
      },
      "source": [
        "def RPtoSE3(R: jnp.ndarray, p: jnp.ndarray) -> np.ndarray:\n",
        "  \"\"\"Rotation and translation to homogeneous transform.\n",
        "\n",
        "  Args:\n",
        "    R: (3, 3) An orthonormal rotation matrix.\n",
        "    p: (3,) A 3-vector representing an offset.\n",
        "\n",
        "  Returns:\n",
        "    X: (4, 4) The homogeneous transformation matrix described by rotating by R\n",
        "      and translating by p.\n",
        "  \"\"\"\n",
        "  p = jnp.reshape(p, (3, 1))\n",
        "  return jnp.block([[R, p], [jnp.array([[0.0, 0.0, 0.0, 1.0]])]])\n",
        "\n",
        "def DecomposeScrew(V: np.ndarray):\n",
        "  \"\"\"Decompose a screw V into a normalized axis and a magnitude.\n",
        "\n",
        "  Args:\n",
        "    V: (6,) A spatial vector describing a screw motion.\n",
        "\n",
        "  Returns:\n",
        "    S: (6,) A unit screw axis.\n",
        "    theta: An angle of rotation such that S * theta = V.\n",
        "  \"\"\"\n",
        "  w, v = jnp.split(V, 2)\n",
        "  w_is_zero = jnp.allclose(w, jnp.zeros_like(w))\n",
        "  v_is_zero = jnp.allclose(v, jnp.zeros_like(v))\n",
        "  both_zero = w_is_zero * v_is_zero\n",
        "\n",
        "  dtheta = jnp.where(\n",
        "      both_zero, 0.0,\n",
        "      jnp.where(1 - w_is_zero, jnp.linalg.norm(w), jnp.linalg.norm(v)))\n",
        "  S = jnp.where(both_zero, V, V / dtheta)\n",
        "\n",
        "  return (S, dtheta)\n",
        "\n",
        "def Skew(w: jnp.ndarray) -> jnp.ndarray:\n",
        "  \"\"\"Build a skew matrix (\"cross product matrix\") for vector w.\n",
        "\n",
        "  Modern Robotics Eqn 3.30.\n",
        "\n",
        "  Args:\n",
        "    w: (3,) A 3-vector\n",
        "\n",
        "  Returns:\n",
        "    W: (3, 3) A skew matrix such that W @ v == w x v\n",
        "  \"\"\"\n",
        "  w = jnp.reshape(w, (3))\n",
        "  return jnp.array([[0.0, -w[2], w[1]],\\\n",
        "                   [w[2], 0.0, -w[0]],\\\n",
        "                   [-w[1], w[0], 0.0]])\n",
        "\n",
        "def ExpSO3(w: jnp.ndarray, theta: float) -> np.ndarray:\n",
        "  \"\"\"Exponential map from Lie algebra so3 to Lie group SO3.\n",
        "\n",
        "  Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula.\n",
        "\n",
        "  Args:\n",
        "    w: (3,) An axis of rotation.\n",
        "    theta: An angle of rotation.\n",
        "\n",
        "  Returns:\n",
        "    R: (3, 3) An orthonormal rotation matrix representing a rotation of\n",
        "      magnitude theta about axis w.\n",
        "  \"\"\"\n",
        "  W = Skew(w)\n",
        "  return jnp.eye(3) + jnp.sin(theta) * W + (1.0 - jnp.cos(theta)) * W @ W\n",
        "\n",
        "def ExpSE3(S: jnp.ndarray, theta: float) -> np.ndarray:\n",
        "  \"\"\"Exponential map from Lie algebra so3 to Lie group SO3.\n",
        "\n",
        "  Modern Robotics Eqn 3.88.\n",
        "\n",
        "  Args:\n",
        "    S: (6,) A screw axis of motion.\n",
        "    theta: Magnitude of motion.\n",
        "\n",
        "  Returns:\n",
        "    a_X_b: (4, 4) The homogeneous transformation matrix attained by integrating\n",
        "      motion of magnitude theta about S for one second.\n",
        "  \"\"\"\n",
        "  w, v = jnp.split(S, 2)\n",
        "  W = Skew(w)\n",
        "  R = ExpSO3(w, theta)\n",
        "  p = (theta * jnp.eye(3) + (1.0 - jnp.cos(theta)) * W +\n",
        "       (theta - jnp.sin(theta)) * W @ W) @ v\n",
        "  return RPtoSE3(R, p)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FzASDdNaKgD4"
      },
      "source": [
        "# iNeRF training."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FqHHhCYkKhW0"
      },
      "source": [
        "def train_step_exp(screw_delta, test_pixels, hwf, batch_size, nerf_model):\n",
        "  \"\"\"\n",
        "  screw_delta (6,): screw of delta pose, relative to initial pose\n",
        "  test_pixels (H*W, 3): ground truth image's pixels.\n",
        "  hwf (3): image height, width, and focal length.\n",
        "  batch_size: number of rays and pixels to sample.\n",
        "  \"\"\"\n",
        "  # rng_key, key_0, key_1 = random.split(rng_key, 3)\n",
        "  rng_key = random.PRNGKey(20200823)\n",
        "  rng_key, key_0, key_1 = random.split(rng_key, 3)\n",
        "\n",
        "  def loss_fn(screw_delta):\n",
        "    \"\"\"screw_delta is a (6,)\"\"\"\n",
        "\n",
        "    # pred_pose_delta is a (4,4) matrix, SE3, relative to pred_pose_init\n",
        "    pred_pose_delta = ExpSE3(*DecomposeScrew(screw_delta))\n",
        "\n",
        "    # pred_pose is the full new estimated pose.\n",
        "    pred_pose = pred_pose_delta @  pred_pose_init\n",
        "\n",
        "    resolution = test_pixels.shape[0]\n",
        "    h, w, f = hwf\n",
        "    x, y = jnp.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking\n",
        "        jnp.arange(w),  # X-Axis (columns)\n",
        "        jnp.arange(h),  # Y-Axis (rows)\n",
        "        indexing=\"xy\")\n",
        "    dirs = jnp.stack([(x - w * 0.5) / f,\n",
        "                      -(y - h * 0.5) / f, -jnp.ones_like(x)],\n",
        "                    axis=-1)\n",
        "    rays_d = ((dirs[None, ..., None, :] * pred_pose[None, None, :3, :3]).sum(axis=-1))\n",
        "    rays_o = jnp.broadcast_to(pred_pose[None, None, :3, -1],\n",
        "                              list(rays_d.shape))\n",
        "    rays = jnp.concatenate([rays_o, rays_d], axis=-1)\n",
        "    rays = rays.reshape([resolution, rays.shape[-1]])\n",
        "\n",
        "    # Sample rays.\n",
        "    ray_sample_indices = np.random.randint(0, resolution, (batch_size,))\n",
        "    \n",
        "    batch_pixels = jnp.array(test_pixels[ray_sample_indices][None, :, :])\n",
        "    batch_rays = rays[ray_sample_indices][None, :, :]\n",
        "    batch = {'pixels': batch_pixels, 'rays': batch_rays}\n",
        "    model_outputs = nerf_model(key_0, key_1, batch_rays[0])\n",
        "    rgb = model_outputs[-1][0]\n",
        "    # MSE\n",
        "    loss = ((rgb - batch[\"pixels\"][0][..., :3])**2).mean()\n",
        "    return loss\n",
        "\n",
        "  # Forward.\n",
        "  grad_pose = jax.value_and_grad(loss_fn)\n",
        "  loss, grad = grad_pose(screw_delta)\n",
        "  return loss, grad"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Donsz1VOKknx"
      },
      "source": [
        "# Start over by re-initializing the initial relative pose.\n",
        "delta_init = np.random.randn(6) * 1e-6\n",
        "screw_delta = delta_init\n",
        "print(screw_delta)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mWnRKrKhKmYg"
      },
      "source": [
        "# This will just re-initialize the optimizier, but not the current guess.\n",
        "\n",
        "initial_learning_rate = 1e-2\n",
        "decay_steps = 100\n",
        "decay_rate = 0.6\n",
        "exp_schedule = optimizers.exponential_decay(initial_learning_rate, decay_steps, decay_rate)\n",
        "step = []\n",
        "rate = []\n",
        "for i in range(1000):\n",
        "  step.append(i)\n",
        "  rate.append(exp_schedule(i))\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "plt.title('learning rate')\n",
        "plt.plot(step, rate)\n",
        "plt.yscale(\"log\")\n",
        "plt.show()\n",
        "\n",
        "opt_init, opt_update, get_params = optimizers.adam(step_size=exp_schedule)\n",
        "opt_state = opt_init(screw_delta)\n",
        "print(screw_delta)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2X4mF58hKoo3"
      },
      "source": [
        "# Inference loop.\n",
        "hwf = (dataset.h, dataset.w, dataset.focal)\n",
        "inference_batch_size = 2048\n",
        "n_iters = 300\n",
        "\n",
        "pred_poses = []\n",
        "R_errors = []\n",
        "t_errors = []\n",
        "losses = []\n",
        "for i in range(n_iters+1):\n",
        "  screw_delta = get_params(opt_state)\n",
        "  pred_pose_delta = ExpSE3(*DecomposeScrew(screw_delta))\n",
        "\n",
        "  pred_pose = pred_pose_delta @ pred_pose_init\n",
        "  loss, grad = train_step_exp(screw_delta, test_pixels, hwf,\n",
        "                              inference_batch_size, nerf_model)\n",
        "  opt_state = opt_update(i, grad, opt_state)\n",
        "\n",
        "  losses.append(loss)\n",
        "  pred_poses.append(np.array(pred_pose))\n",
        "  R_error, t_error = compute_pose_error(np.array(pred_pose), test_pose)\n",
        "  R_errors.append(R_error)\n",
        "  t_errors.append(t_error)\n",
        "  if i % 50 == 0:\n",
        "    print(f\"{i}/{n_iters} iterations ...\")\n",
        "    print(f\"loss: {loss} | R error: {R_error} | t error: {t_error}\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nzv4547EKqx4"
      },
      "source": [
        "print(len(losses))\n",
        "\n",
        "fig, axes = plt.subplots(3, 1)\n",
        "fig.tight_layout()\n",
        "\n",
        "axes[0].set_title('MSE Loss')\n",
        "axes[0].plot(range(n_iters+1), losses)\n",
        "axes[1].set_title('Rotation Error')\n",
        "axes[1].plot(range(n_iters+1), R_errors, color='r')\n",
        "axes[2].set_title('Translation Error')\n",
        "axes[2].plot(range(n_iters+1), t_errors, color='c')\n",
        "\n",
        "# To make these each log scale on y axes\n",
        "[ax.set_yscale('log') for ax in axes]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "InyyL6WUKrvx"
      },
      "source": [
        "# Show video."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FnJfGKopKspK"
      },
      "source": [
        "render_fn = jax.pmap(\n",
        "    # Note rng_keys are useless in eval mode since there's no randomness.\n",
        "    # pylint: disable=g-long-lambda\n",
        "    lambda key_0, key_1, model, rays: jax.lax.all_gather(\n",
        "        model(key_0, key_1, rays), axis_name=\"batch\"),\n",
        "    in_axes=(None, None, None, 0),  # Only distribute the data input.\n",
        "    donate_argnums=3,\n",
        "    axis_name=\"batch\",\n",
        ")\n",
        "\n",
        "render_fn_jit = jit(render_fn)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0_Estqr0Ku_P"
      },
      "source": [
        "def get_batch(pred_pose, test_pixels, hwf, batch_size):\n",
        "    resolution = test_pixels.shape[0]\n",
        "    h, w, f = hwf\n",
        "    x, y = jnp.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking\n",
        "        jnp.arange(w),  # X-Axis (columns)\n",
        "        jnp.arange(h),  # Y-Axis (rows)\n",
        "        indexing=\"xy\")\n",
        "    dirs = np.stack([(x - w * 0.5) / f,\n",
        "                      -(y - h * 0.5) / f, -jnp.ones_like(x)],\n",
        "                    axis=-1)\n",
        "    rays_d = ((dirs[None, ..., None, :] * pred_pose[None, None, :3, :3]).sum(axis=-1))\n",
        "    rays_o = jnp.broadcast_to(pred_pose[None, None, :3, -1],\n",
        "                              list(rays_d.shape))\n",
        "    rays = jnp.concatenate([rays_o, rays_d], axis=-1)[0]\n",
        "\n",
        "    batch = {'rays': rays}\n",
        "    return batch"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DoaR69SkKwph"
      },
      "source": [
        "n_frames = 20\n",
        "images = []\n",
        "n_iters_reduced = 200  # set this different from n_iters, so can focus on the part where the most happens.\n",
        "for idx in range(0, n_iters_reduced+1, n_iters_reduced//n_frames):\n",
        "  print(\"Rendering, \", idx)\n",
        "  batch = get_batch(pred_poses[idx], test_pixels, hwf, batch_size)\n",
        "  pred_color, pred_disp, pred_acc = utils.render_image(\n",
        "      state, batch, render_fn_jit, rng, chunk=8192)\n",
        "  images.append(pred_color)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0kfATXPYK2RH"
      },
      "source": [
        "def save_animation(images, test_image):\n",
        "  pil_ims = [\n",
        "      PilImage.fromarray(\n",
        "          (np.clip(np.array(im), 0.0, 1.0) * 255.0).astype(np.uint8))\n",
        "      for im in images\n",
        "  ]\n",
        "  test_im = PilImage.fromarray(\n",
        "          (np.clip(test_image, 0.0, 1.0) * 255.0).astype(np.uint8))\n",
        "  pil_ims = [PIL.Image.blend(im, test_im, 0.5) for im in pil_ims]\n",
        "  pil_ims[0].save(\n",
        "      '/tmp/optimization_animation.gif',\n",
        "      save_all=True,\n",
        "      append_images=pil_ims[1:],\n",
        "      duration=200,\n",
        "      loop=0)\n",
        "  colabtools.publish.image('/tmp/optimization_animation.gif')\n",
        "\n",
        "\n",
        "save_animation(images, test_image)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}